import torch
from torch import nn
from torch.nn import functional as F

from torch import distributed as dist
from image_synthesis.modeling.codecs.base_codec import BaseCodec
from image_synthesis.modeling.utils.misc import get_token_type

def get_world_size():
    """Return the number of processes taking part in distributed training.

    Falls back to 1 when ``torch.distributed`` is not available in this build
    or when no process group has been initialised.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1


def all_reduce(tensor, op=dist.ReduceOp.SUM):
    """Reduce ``tensor`` in place across all processes and return it.

    When ``get_world_size()`` reports a single process, no collective is
    issued and the tensor is handed back untouched.

    Args:
        tensor: tensor to reduce; modified in place by ``dist.all_reduce``.
        op: reduction operator forwarded to ``dist.all_reduce``.

    Returns:
        The same ``tensor`` object that was passed in.
    """
    if get_world_size() != 1:
        dist.all_reduce(tensor, op=op)
    return tensor

################### quantize facilities #############################
class Quantize(nn.Module):
    """Nearest-neighbour vector quantiser with an EMA-updated codebook.

    The codebook is kept in buffers rather than parameters, so it receives no
    gradient. In training mode every forward pass refreshes it from
    exponential moving averages of the per-code assignment counts and of the
    per-code sums of the assigned inputs.

    Args:
        dim: size of one code vector; inputs are flattened to ``(-1, dim)``.
        n_embed: number of codes in the codebook.
        decay: decay factor of the exponential moving averages.
        eps: smoothing constant used when normalising the cluster sizes.
    """

    def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
        super().__init__()

        self.dim = dim
        self.n_embed = n_embed
        self.decay = decay
        self.eps = eps

        # The codebook is stored column-wise: embed[:, k] is code k.
        embed = torch.randn(dim, n_embed)
        self.register_buffer("embed", embed)
        self.register_buffer("cluster_size", torch.zeros(n_embed))
        self.register_buffer("embed_avg", embed.clone())

    def forward(self, input):
        """Quantise ``input`` and, in training mode, update the codebook.

        Args:
            input: tensor whose elements can be reshaped to ``(-1, dim)``; the
                last axis is treated as the code dimension.

        Returns:
            A tuple ``(quantize, diff, embed_ind)`` where ``quantize`` carries
            the values of the selected codes while passing gradients straight
            through to ``input``, ``diff`` is the mean squared difference
            between the detached codes and ``input``, and ``embed_ind`` holds
            the selected code indices with shape ``input.shape[:-1]``.
        """
        flatten = input.reshape(-1, self.dim)
        # Squared Euclidean distance expanded as |x|^2 - 2 x.e + |e|^2.
        dists = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed
            + self.embed.pow(2).sum(0, keepdim=True)
        )
        # Closest code per row, written as the max of the negated distances.
        _, flat_ind = (-dists).max(1)
        embed_ind = flat_ind.view(*input.shape[:-1])
        # Look the codes up before the codebook is modified below.
        quantize = self.embed_code(embed_ind)

        if self.training:
            self._update_codebook(flatten, flat_ind)

        diff = (quantize.detach() - input).pow(2).mean()
        # Straight-through estimator: forward value is the code, gradient
        # flows to `input` only.
        quantize = input + (quantize - input).detach()

        return quantize, diff, embed_ind

    def _update_codebook(self, flatten, flat_ind):
        """Fold the current batch's assignments into the EMA codebook statistics."""
        # The one-hot matrix is (rows, n_embed); it is only needed for the
        # statistics below, so it is built here rather than on every forward.
        embed_onehot = F.one_hot(flat_ind, self.n_embed).type(flatten.dtype)
        embed_onehot_sum = embed_onehot.sum(0)
        embed_sum = flatten.transpose(0, 1) @ embed_onehot

        # Sum the batch statistics over all processes (no-op for one process).
        all_reduce(embed_onehot_sum)
        all_reduce(embed_sum)

        self.cluster_size.data.mul_(self.decay).add_(
            embed_onehot_sum, alpha=1 - self.decay
        )
        self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
        n = self.cluster_size.sum()
        # Smoothed cluster sizes so that codes with no assignments do not
        # divide by zero.
        cluster_size = (
            (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
        )
        embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
        self.embed.data.copy_(embed_normalized)

    def embed_code(self, embed_id):
        """Return the code vectors for the integer indices ``embed_id``."""
        return F.embedding(embed_id, self.embed.transpose(0, 1))

def Quantize_inference(embed, x, only_dist=False):
    """Quantise ``x`` against an ``nn.Embedding`` codebook without updating it.

    Args:
        embed: embedding module whose ``weight`` is ``[n_embed, embed_dim]``.
        x: features laid out as ``[n, h, w, c]``.
        only_dist: when True, return only the squared distances.

    Returns:
        If ``only_dist`` is True, a contiguous ``[n, n_embed, h, w]`` tensor
        of squared Euclidean distances between every position and every
        code. Otherwise a tuple ``(quantize, diff, embed_ind)``: the selected
        codes with a straight-through gradient to ``x``, the mean squared
        difference between the detached codes and ``x``, and the
        ``[n, h, w]`` code indices.
    """
    n, h, w, c = x.shape
    n_embed = embed.weight.shape[0]
    codebook_t = embed.weight.transpose(0, 1)  # [embed_dim, n_embed]
    flatten = x.reshape(-1, c)
    # Squared Euclidean distance expanded as |x|^2 - 2 x.e + |e|^2.
    dists = (
        flatten.pow(2).sum(1, keepdim=True)
        - 2 * flatten @ codebook_t
        + codebook_t.pow(2).sum(0, keepdim=True)
    )

    if only_dist:
        return dists.view(n, h, w, n_embed).permute(0, 3, 1, 2).contiguous()  # [N, n_embed, H, W]

    # Closest code per position, written as the max of the negated distances.
    _, embed_ind = (-dists).max(1)
    embed_ind = embed_ind.view(*x.shape[:-1])
    quantize = embed(embed_ind)

    diff = (quantize.detach() - x).pow(2).mean()
    # Straight-through estimator: forward value is the code, gradient goes to x.
    quantize = x + (quantize - x).detach()

    return quantize, diff, embed_ind

def Quantize_inference_norm(embed, x, only_dist=False):
    """Quantise ``x`` by cosine similarity against an ``nn.Embedding`` codebook.

    Both ``x`` (along its last axis) and the codebook rows are L2-normalised
    before their dot products are taken.

    Args:
        embed: embedding module whose ``weight`` is ``[n_embed, embed_dim]``.
        x: features laid out as ``[n, h, w, c]``.
        only_dist: when True, return only the similarity map.

    Returns:
        If ``only_dist`` is True, a contiguous ``[n, n_embed, h, w]``
        similarity tensor. Otherwise ``(quantize, diff, embed_ind)`` where
        ``embed_ind`` is the most similar code per position.
    """
    n, h, w, c = x.shape  # unpacking also rejects inputs that are not 4-D
    x = F.normalize(x, p=2, dim=-1)
    unit_codes = F.normalize(embed.weight, p=2, dim=-1)

    similarity = torch.einsum('n h w d, j d -> n h w j', x, unit_codes)  # [N, H, W, n_embed]
    similarity = similarity.permute(0, 3, 1, 2).contiguous()  # [N, n_embed, H, W]
    if only_dist:
        return similarity

    embed_ind = similarity.argmax(dim=1)
    # NOTE(review): the codes are looked up from the un-normalised table while
    # x has been normalised above -- confirm this asymmetry is intended.
    codes = embed(embed_ind)
    diff = (codes.detach() - x).pow(2).mean()
    # Straight-through estimator: forward value is the code, gradient goes to x.
    straight_through = x + (codes - x).detach()
    return straight_through, diff, embed_ind

#################### blocks for encoder and decoder ###################
# vqvae encoder and decoder
class ResBlock(nn.Module):
    """Residual block: ReLU -> 3x3 conv -> ReLU -> 1x1 conv, added to the input.

    Args:
        in_channel: channels of the input (and of the output).
        channel: channels of the hidden 3x3 convolution.
    """

    def __init__(self, in_channel, channel):
        super().__init__()

        residual_path = [
            nn.ReLU(inplace=False),
            nn.Conv2d(in_channel, channel, 3, padding=1),
            nn.ReLU(inplace=False),
            nn.Conv2d(channel, in_channel, 1),
        ]
        self.conv = nn.Sequential(*residual_path)

    def forward(self, input):
        """Return ``input`` plus the output of the residual path."""
        return self.conv(input) + input

# blocks like ResNet
def make_norm_layer(inplanes, norm_layer):
    """Build the normalisation layer selected by ``norm_layer``.

    Args:
        inplanes: number of channels the layer normalises.
        norm_layer: ``'bn'`` for ``nn.BatchNorm2d``, ``'gn'`` for a 32-group
            ``nn.GroupNorm``; any other value selects no normalisation.

    Returns:
        The constructed module, or ``None`` when no normalisation is selected.
    """
    if norm_layer == 'bn':
        return nn.BatchNorm2d(inplanes)
    if norm_layer == 'gn':
        return nn.GroupNorm(num_groups=32, num_channels=inplanes, eps=1e-6, affine=True)
    return None

class ResNetBasicBlock(nn.Module):
    """Two-convolution residual block that can keep, halve or double resolution.

    With ``block_type='downsample'`` the block maps ``inplanes`` channels to
    ``planes * expansion`` channels and halves the resolution when
    ``stride == 2``. With any other ``block_type`` it maps
    ``inplanes * expansion`` channels to ``planes`` channels and doubles the
    resolution when ``stride == 2``. ``stride == 1`` keeps the resolution.

    Args:
        inplanes: channel count on the input side (see above).
        planes: channel count on the output side (see above).
        stride: 1 or 2.
        base_width: accepted for signature compatibility; not used here.
        norm_layer: ``'none'``, ``'bn'`` or ``'gn'`` (see ``make_norm_layer``).
        block_type: ``'downsample'`` for the encoder variant, anything else
            for the upsampling variant.
        sample_type: ``'conv'`` resamples with strided / transposed
            convolutions; any other value uses ``nn.Upsample`` plus a
            convolution on the shortcut (and, in the upsampling variant, on
            the main path as well).
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1,
                 base_width=64, norm_layer='none', block_type='downsample',
                 sample_type='conv'):
        super(ResNetBasicBlock, self).__init__()
        assert norm_layer in ['none', 'bn', 'gn']
        assert stride in [1, 2]
        self.stride = stride
        self.inplanes = inplanes

        if block_type == 'downsample':
            # conv1: a strided 4x4 conv halves the resolution, otherwise 3x3
            if stride != 1:
                self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=4, stride=2, padding=1)
            else:
                self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, padding=1)
            self.norm1 = make_norm_layer(planes, norm_layer)

            # conv2
            self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                                   padding=1, groups=1, bias=False, dilation=1)
            self.norm2 = make_norm_layer(planes, norm_layer)
        else:
            # conv1: resampling happens first, at the input channel count
            if stride != 1:
                if sample_type == 'conv':
                    self.conv1 = nn.ConvTranspose2d(inplanes, inplanes, kernel_size=4, stride=2, padding=1)
                else:
                    self.conv1 = nn.Sequential(
                        nn.Upsample(scale_factor=2, mode='nearest'),
                        nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1)
                    )
            else:
                self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1)
            self.norm1 = make_norm_layer(inplanes, norm_layer)

            # conv2: changes the channel count to `planes`
            self.conv2 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=1,
                                   padding=1, groups=1, bias=False, dilation=1)
            self.norm2 = make_norm_layer(planes, norm_layer)

        self.relu = nn.ReLU(inplace=False)

        # Shortcut branch: matches the main path's resolution and channels.
        if block_type == 'downsample':
            if stride != 1:
                if sample_type == 'conv':
                    layers_ = [
                        nn.Conv2d(inplanes, planes * self.expansion, kernel_size=4, stride=2, padding=1)
                    ]
                else:
                    # 'bilinear' (the original 'blinear' is not a valid
                    # interpolation mode and failed on the first forward pass)
                    layers_ = [
                        nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
                        nn.Conv2d(inplanes, planes * self.expansion, kernel_size=1, stride=1)
                    ]
            elif inplanes != planes * self.expansion:
                layers_ = [
                    nn.Conv2d(inplanes, planes * self.expansion, kernel_size=1, stride=1)
                ]
            else:
                layers_ = []
            norm = make_norm_layer(planes * self.expansion, norm_layer)
        else:
            if stride != 1:
                if sample_type == 'conv':
                    layers_ = [
                        nn.ConvTranspose2d(inplanes * self.expansion, planes, kernel_size=4, stride=2, padding=1)
                    ]
                else:
                    layers_ = [
                        nn.Upsample(scale_factor=2, mode='nearest'),
                        nn.Conv2d(inplanes * self.expansion, planes, kernel_size=1, stride=1)
                    ]
            elif inplanes * self.expansion != planes:
                layers_ = [
                    nn.Conv2d(inplanes * self.expansion, planes, kernel_size=1, stride=1)
                ]
            else:
                layers_ = []

            norm = make_norm_layer(planes, norm_layer)
        # When a norm is configured it is applied on the shortcut even if no
        # projection is needed.
        if norm is not None:
            layers_.append(norm)
        if len(layers_) > 0:
            self.downsample = nn.Sequential(*layers_)
        else:
            self.downsample = None

    def forward(self, x):
        """Apply conv1 -> norm -> ReLU -> conv2 -> norm, add the shortcut, ReLU."""
        identity = x

        out = self.conv1(x)
        if self.norm1 is not None:
            out = self.norm1(out)
        out = self.relu(out)

        out = self.conv2(out)
        if self.norm2 is not None:
            out = self.norm2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class ResNetBottleneck(nn.Module):
    """Three-convolution bottleneck block that can keep, halve or double resolution.

    With ``block_type='downsample'`` the block maps ``inplanes`` channels to
    ``planes * expansion`` channels and halves the resolution when
    ``stride == 2``. With any other ``block_type`` it maps
    ``inplanes * expansion`` channels to ``planes`` channels and doubles the
    resolution when ``stride == 2``. ``stride == 1`` keeps the resolution.

    Args:
        inplanes: channel count on the input side (see above).
        planes: channel count on the output side (see above).
        stride: 1 or 2.
        downsample: accepted for signature compatibility; not used, the
            shortcut branch is always built here.
        norm_layer: ``'none'``, ``'bn'`` or ``'gn'`` (see ``make_norm_layer``).
        block_type: ``'downsample'`` for the encoder variant, anything else
            for the upsampling variant.
        sample_type: ``'conv'`` resamples with strided / transposed
            convolutions; any other value uses ``nn.Upsample`` plus a
            convolution on the shortcut (and, in the upsampling variant, on
            the main path as well).
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 norm_layer='none', block_type='downsample',
                 sample_type='conv'):
        super(ResNetBottleneck, self).__init__()
        assert norm_layer in ['none', 'bn', 'gn']
        assert stride in [1, 2]
        self.block_type = block_type
        self.stride = stride
        self.inplanes = inplanes

        # Both self.conv2 and self.downsample resample the input when stride != 1
        if block_type == 'downsample':
            # conv1: 1x1 channel reduction
            self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1)
            self.norm1 = make_norm_layer(planes, norm_layer)

            # conv2: a strided 4x4 conv halves the resolution, otherwise 3x3
            if stride != 1:
                self.conv2 = nn.Conv2d(planes, planes, kernel_size=4, stride=2, padding=1)
            else:
                self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1)
            self.norm2 = make_norm_layer(planes, norm_layer)

            # conv3: 1x1 channel expansion
            self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, stride=1)
            self.norm3 = make_norm_layer(planes * self.expansion, norm_layer)
        else:
            # conv1: 1x1 reduction from the expanded input width
            self.conv1 = nn.Conv2d(inplanes * self.expansion, inplanes, kernel_size=1, stride=1)
            self.norm1 = make_norm_layer(inplanes, norm_layer)

            # conv2: optional x2 upsampling
            if stride != 1:
                if sample_type == 'conv':
                    self.conv2 = nn.ConvTranspose2d(inplanes, inplanes, kernel_size=4, stride=2, padding=1)
                else:
                    self.conv2 = nn.Sequential(
                        nn.Upsample(scale_factor=2, mode='nearest'),
                        nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1)
                    )
            else:
                self.conv2 = nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1)
            self.norm2 = make_norm_layer(inplanes, norm_layer)

            # conv3: 1x1 projection to the output width
            self.conv3 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1)
            self.norm3 = make_norm_layer(planes, norm_layer)

        self.relu = nn.ReLU(inplace=False)

        # Shortcut branch: matches the main path's resolution and channels.
        if block_type == 'downsample':
            if stride != 1:
                if sample_type == 'conv':
                    layers_ = [
                        nn.Conv2d(inplanes, planes * self.expansion, kernel_size=4, stride=2, padding=1)
                    ]
                else:
                    # 'bilinear' (the original 'blinear' is not a valid
                    # interpolation mode and failed on the first forward pass)
                    layers_ = [
                        nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
                        nn.Conv2d(inplanes, planes * self.expansion, kernel_size=1, stride=1)
                    ]
            elif inplanes != planes * self.expansion:
                layers_ = [
                    nn.Conv2d(inplanes, planes * self.expansion, kernel_size=1, stride=1)
                ]
            else:
                layers_ = []
            norm = make_norm_layer(planes * self.expansion, norm_layer)
        else:
            if stride != 1:
                if sample_type == 'conv':
                    layers_ = [
                        nn.ConvTranspose2d(inplanes * self.expansion, planes, kernel_size=4, stride=2, padding=1)
                    ]
                else:
                    layers_ = [
                        nn.Upsample(scale_factor=2, mode='nearest'),
                        nn.Conv2d(inplanes * self.expansion, planes, kernel_size=1, stride=1)
                    ]
            elif inplanes * self.expansion != planes:
                layers_ = [
                    nn.Conv2d(inplanes * self.expansion, planes, kernel_size=1, stride=1)
                ]
            else:
                layers_ = []

            norm = make_norm_layer(planes, norm_layer)
        # When a norm is configured it is applied on the shortcut even if no
        # projection is needed.
        if norm is not None:
            layers_.append(norm)
        if len(layers_) > 0:
            self.downsample = nn.Sequential(*layers_)
        else:
            self.downsample = None

    def forward(self, x):
        """Apply the three conv/norm stages, add the shortcut, then ReLU."""
        identity = x

        out = self.conv1(x)
        if self.norm1 is not None:
            out = self.norm1(out)
        out = self.relu(out)

        out = self.conv2(out)
        if self.norm2 is not None:
            out = self.norm2(out)
        out = self.relu(out)

        out = self.conv3(out)
        if self.norm3 is not None:
            out = self.norm3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class ResNetEncoder(nn.Module):
    """ResNet-style image encoder built from downsampling residual stages.

    A stride-1 stem lifts the 3-channel image to 64 channels, then one stage
    per entry of ``layers`` follows (stage ``i`` uses ``64 * 2**i`` planes).
    Every stage except the last halves the resolution. A final ``ResBlock``
    plus ReLU produces the output, whose channel count is exposed as
    ``out_channels``.

    Args:
        body_type: key into ``layer`` / ``block`` selecting depth and block class.
        norm_layer: ``'none'``, ``'bn'`` or ``'gn'``.
        sample_type: forwarded to every residual block.
        hidden_planes: hidden width of the output ``ResBlock``.
        layers: optional list overriding the per-stage block counts.
    """

    # number of blocks per stage for each body type
    layer = {
        'resnet18': [2, 2, 2, 2],
        'resnet34': [3, 4, 6, 3],
        'resnet50': [3, 4, 6, 3],
        'resnet101': [3, 4, 23, 3],
    }
    # residual block class for each body type
    block = {
        'resnet18': ResNetBasicBlock,
        'resnet34': ResNetBasicBlock,
        'resnet50': ResNetBottleneck,
        'resnet101': ResNetBottleneck,
    }

    def __init__(
        self,
        body_type='resnet18',
        norm_layer='none',
        sample_type='conv',
        hidden_planes=256,
        layers = None # list or None
        ):
        super().__init__()

        self.inplanes = 64
        self.norm_layer = norm_layer
        self.sample_type = sample_type

        # stem: 3x3 stride-1 conv (resolution unchanged), optional norm, ReLU
        stem = [nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1)]
        stem_norm = make_norm_layer(self.inplanes, norm_layer)
        if stem_norm is not None:
            stem.append(stem_norm)
        stem.append(nn.ReLU(inplace=False))
        self.layer_in = nn.Sequential(*stem)

        # backbone stages, registered as layer1 .. layerN
        if layers is None:
            layers = self.layer[body_type]
        block_cls = self.block[body_type]
        self.num_layers = len(layers)
        last = self.num_layers - 1
        for idx, depth in enumerate(layers):
            stage = self._make_layer(
                block=block_cls,
                layer=depth,
                planes=64 * 2 ** idx,
                stride=1 if idx == last else 2,  # the last stage keeps the resolution
            )
            setattr(self, 'layer{}'.format(idx + 1), stage)

        # output head
        self.layer_out = nn.Sequential(
            ResBlock(self.inplanes, hidden_planes),
            nn.ReLU(inplace=False)
        )
        self.out_channels = self.inplanes

    def _make_layer(self, block, layer, planes, stride):
        """Build one stage: a (possibly strided) block followed by ``layer - 1`` stride-1 blocks."""
        shared = dict(
            sample_type=self.sample_type,
            norm_layer=self.norm_layer,
            block_type='downsample',
        )
        blocks = [block(self.inplanes, planes, stride=stride, **shared)]
        # later blocks (and later stages) consume the first block's output width
        self.inplanes = planes * block.expansion
        for _ in range(layer - 1):
            blocks.append(block(self.inplanes, planes, stride=1, **shared))
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Run the stem, every stage in order, then the output head."""
        x = self.layer_in(x)
        for number in range(1, self.num_layers + 1):
            stage = getattr(self, 'layer{}'.format(number))
            x = stage(x)
        return self.layer_out(x)

class ResNetDecoder(nn.Module):
    """ResNet-style image decoder that mirrors ``ResNetEncoder`` in reverse.

    The input is expected to have ``in_channels`` channels. It passes through
    a ReLU / ``ResBlock`` / ReLU head, then through the stages from the
    highest index down to ``layer1`` (every stage except the highest-index
    one doubles the resolution), and finally a stride-1 3x3 convolution maps
    to 3 channels.

    Args:
        body_type: key into ``layer`` / ``block`` selecting depth and block class.
        norm_layer: ``'none'``, ``'bn'`` or ``'gn'``.
        sample_type: forwarded to every residual block.
        hidden_planes: hidden width of the input-side ``ResBlock``.
        layers: optional list overriding the per-stage block counts.
    """

    # number of blocks per stage for each body type
    layer = {
        'resnet18': [2, 2, 2, 2],
        'resnet34': [3, 4, 6, 3],
        'resnet50': [3, 4, 6, 3],
        'resnet101': [3, 4, 23, 3],
    }
    # residual block class for each body type
    block = {
        'resnet18': ResNetBasicBlock,
        'resnet34': ResNetBasicBlock,
        'resnet50': ResNetBottleneck,
        'resnet101': ResNetBottleneck,
    }

    def __init__(
        self,
        body_type='resnet18',
        norm_layer='none',
        sample_type='conv',
        hidden_planes=256,
        layers = None # list or None
        ):
        super().__init__()

        self.norm_layer = norm_layer
        self.sample_type = sample_type

        if layers is None:
            layers = self.layer[body_type]
        block_cls = self.block[body_type]
        self.num_layers = len(layers)
        top = self.num_layers - 1
        # channel count at the decoder input (width of the deepest stage)
        self.inplanes = 64 * 2 ** top * block_cls.expansion
        self.in_channels = self.inplanes

        # input-side head (named layer_out to mirror the encoder)
        self.layer_out = nn.Sequential(
            nn.ReLU(inplace=False),
            ResBlock(self.inplanes, hidden_planes),
            nn.ReLU(inplace=False)
        )

        # backbone stages, built and registered from layerN down to layer1
        for idx in reversed(range(self.num_layers)):
            stage = self._make_layer(
                block=block_cls,
                layer=layers[idx],
                planes=64 * 2 ** idx,
                stride=1 if idx == top else 2,  # the highest-index stage keeps the resolution
            )
            setattr(self, 'layer{}'.format(idx + 1), stage)

        # final stride-1 conv to 3 channels (resolution unchanged)
        self.layer_in = nn.Sequential(
            nn.Conv2d(self.inplanes, 3, kernel_size=3, stride=1, padding=1)
        )

    def _make_layer(self, block, layer, planes, stride):
        """Build one stage: ``layer - 1`` stride-1 blocks, then a (possibly strided) block.

        Note the argument order: the block's first positional parameter
        receives this stage's ``planes`` and its second receives the running
        channel count ``self.inplanes``.
        """
        shared = dict(
            sample_type=self.sample_type,
            norm_layer=self.norm_layer,
            block_type='upsample',
        )
        blocks = [
            block(planes, self.inplanes, stride=1, **shared)
            for _ in range(layer - 1)
        ]
        # the closing block narrows to the width consumed by the next stage
        self.inplanes = (planes // 2) * block.expansion
        blocks.append(block(planes, self.inplanes, stride=stride, **shared))
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Run the head, the stages from layerN down to layer1, then the final conv."""
        x = self.layer_out(x)
        for number in range(self.num_layers, 0, -1):
            stage = getattr(self, 'layer{}'.format(number))
            x = stage(x)
        return self.layer_in(x)


class VQVAE(nn.Module):
    """ResNet encoder and decoder joined by an EMA vector quantiser.

    Args:
        embed_dim: dimension of each codebook vector.
        n_embed: number of codebook vectors.
        decay: EMA decay passed to ``Quantize``.
        eps: smoothing constant passed to ``Quantize``.
        body_type, norm_layer, sample_type, hidden_planes, layers: forwarded
            unchanged to both ``ResNetEncoder`` and ``ResNetDecoder``.
    """

    def __init__(
        self,
        embed_dim=512,
        n_embed=512,
        decay=0.99,
        eps=1.0e-5,
        body_type='resnet18',
        norm_layer='none',
        sample_type='conv',
        hidden_planes=256,
        layers = None # list or None
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed

        # encoder and decoder share the same backbone configuration
        backbone_cfg = dict(
            body_type=body_type,
            norm_layer=norm_layer,
            sample_type=sample_type,
            hidden_planes=hidden_planes,
            layers=layers,
        )

        self.enc = ResNetEncoder(**backbone_cfg)
        # 1x1 conv from encoder width to code dimension
        self.quantize_en_conv = nn.Conv2d(self.enc.out_channels, embed_dim, 1)
        self.quantize = Quantize(embed_dim, n_embed, decay=decay, eps=eps)
        self.dec = ResNetDecoder(**backbone_cfg)
        # 1x1 conv from code dimension back to decoder width
        self.quantize_de_conv = nn.Conv2d(embed_dim, self.dec.in_channels, 1)

    def encode(self, input):
        """Encode and quantise ``input``.

        Returns:
            ``(quant, diff, idx)``: the quantised features in channel-first
            layout, the quantiser's ``diff`` term and the code indices.
        """
        features = self.quantize_en_conv(self.enc(input))
        # the quantiser works channel-last: [n, h, w, c]
        quant, diff, idx = self.quantize(features.permute(0, 2, 3, 1))
        return quant.permute(0, 3, 1, 2), diff, idx

    def decode(self, quant):
        """Decode channel-first quantised features back to an image tensor."""
        return self.dec(self.quantize_de_conv(quant))

class ResNetVQVAE(BaseCodec):
    """Codec wrapper around ``VQVAE`` that maps images to tokens and back.

    Images enter as 0..255 values, are rescaled to -1..1 for the network, and
    reconstructions are rescaled and clamped back to 0..255.

    Args:
        embed_dim: dimension of each codebook vector.
        n_embed: number of codebook vectors.
        token_shape: spatial shape of the token grid used by ``decode`` and
            ``get_tokens``; stored as a tuple.
        trainable: stored on the instance before ``_set_trainable()`` is
            called (presumably provided by ``BaseCodec`` -- verify there).
        decay: EMA decay of the quantiser.
        eps: smoothing constant of the quantiser.
        body_type, norm_layer, sample_type, hidden_planes, layers: forwarded
            to ``VQVAE``.
        ckpt_path: optional checkpoint to restore from at construction time.
    """

    def __init__(
        self,
        embed_dim=512,
        n_embed=512,
        token_shape=[16, 16],
        trainable=True,
        decay=0.99,
        eps=1e-5,
        body_type='resnet34',
        norm_layer='none',
        sample_type='conv',
        hidden_planes=256,
        layers = None, # list or None
        ckpt_path=None,
    ):
        super().__init__()
        self.vqvae = VQVAE(
            embed_dim=embed_dim,
            n_embed=n_embed,
            # decay / eps were previously accepted but never forwarded, so
            # configured values had no effect on the quantiser
            decay=decay,
            eps=eps,
            body_type=body_type,
            norm_layer=norm_layer,
            sample_type=sample_type,
            hidden_planes=hidden_planes,
            layers = layers
        )

        self.criterion = nn.MSELoss()
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)

        self.token_shape = tuple(token_shape) #None
        self.trainable = trainable
        self._set_trainable()

    def init_from_ckpt(self, path, ignore_keys=list()):
        """Load weights from the ``"model"`` entry of the checkpoint at ``path``.

        Args:
            path: checkpoint file readable by ``torch.load``.
            ignore_keys: key prefixes; every state-dict entry starting with
                one of them is dropped before loading. The default list is
                never mutated.
        """
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        sd = torch.load(path, map_location="cpu")["model"]
        prefixes = tuple(ignore_keys)
        for k in list(sd.keys()):
            # A single tuple-based startswith drops each key at most once,
            # even when several prefixes match it.
            if k.startswith(prefixes):
                print("ResNetVQVAE deleting key {} from state_dict.".format(k))
                del sd[k]
        self.load_state_dict(sd, strict=False)
        print("ResNetVQVAE restored from {}".format(path))

    @property
    def device(self):
        """Device holding the weights of ``vqvae.quantize_de_conv``."""
        return self.vqvae.quantize_de_conv.weight.device

    def pre_process(self, data):
        """Move ``data`` to the model device and rescale 0..255 to -1..1."""
        data = data.to(self.device)
        data = data / 255.0
        data = (data - 0.5) / 0.5
        return data

    def post_process(self, data):
        """Rescale -1..1 network output to 0..255 and clamp to that range."""
        data = (data * 0.5 + 0.5) * 255.0
        data = torch.clamp(data, min=0.0, max=255.0)
        return data

    def forward(self, batch, return_loss=True, **kwargs):
        """Compute the training losses for ``batch['image']``.

        ``return_loss`` and ``kwargs`` are accepted but not used.

        Returns:
            Dict with ``'latent_loss'``, ``'rec_loss'`` and the combined
            ``'loss'`` (reconstruction MSE plus 0.25 times the latent loss).
        """
        img = self.pre_process(batch['image'])
        quant, latent_loss, _ = self.vqvae.encode(img)
        dec = self.vqvae.decode(quant)
        recon_loss = self.criterion(dec, img)
        loss = recon_loss + 0.25 * latent_loss
        output = {
            'latent_loss': latent_loss,
            'rec_loss': recon_loss,
            'loss': loss
        }

        return output

    @torch.no_grad()
    def sample(self, batch):
        """Reconstruct ``batch['image']`` and return it next to the input."""
        quant, _, _ = self.vqvae.encode(self.pre_process(batch['image']))
        dec = self.vqvae.decode(quant)
        return {'input': batch['image'], 'reconstruction': self.post_process(dec)}

    def get_tokens(self, img, mask=None, **kwargs):
        """Tokenise ``img`` and, optionally, its masked version.

        Args:
            img: image batch with 0..255 values.
            mask: optional mask multiplied onto the pre-processed image
                (commented in the original as B x 1 x H x W).

        Returns:
            Without a mask, ``{'token': ...}`` with the flattened tokens.
            With a mask, a dict holding the tokens of the unmasked image
            (``'target'``), the tokens of the masked image (``'token'``), the
            flattened result of ``get_token_type`` (``'token_type'``) and a
            boolean ``'mask'`` that is True where the token type equals 0.
        """
        img = self.pre_process(img)
        _, _, token = self.vqvae.encode(img) # N x H x W

        if mask is not None: # mask should be B x 1 x H x W
            _, _, token_mask = self.vqvae.encode(img * mask.to(img)) # N x H x W
            token_type = get_token_type(mask, self.token_shape) # B x 1 x H x W
            mask = token_type == 0
            output = {
                'target': token.view(token.shape[0], -1),
                'mask': mask.view(mask.shape[0], -1),
                'token': token_mask.view(token_mask.shape[0], -1),
                'token_type': token_type.view(token_type.shape[0], -1),
            }
        else:
            output = {'token': token.view(token.shape[0], -1)}

        return output

    def decode(self, token):
        """Decode tokens (reshaped to ``token_shape``) into 0..255 images."""
        assert self.token_shape is not None
        token = token.view(-1, *self.token_shape)
        quant = self.vqvae.quantize.embed_code(token)
        # embed_code returns channel-last codes; the decoder wants channel-first
        quant = quant.permute(0, 3, 1, 2)
        dec = self.vqvae.decode(quant)

        return self.post_process(dec)

    def get_rec_loss(self, input, rec):
        """Return the MSE between ``rec`` and ``input``.

        A tensor whose maximum exceeds 1 is treated as being in 0..255 and is
        run through ``pre_process`` first.
        """
        if input.max() > 1:
            input = self.pre_process(input)
        if rec.max() > 1:
            rec = self.pre_process(rec)

        rec_loss = self.criterion(rec, input)
        return rec_loss
